#include "libPJutil/biniostream.h"
#include "blitz/tinyvec.h"
#include "blitz/array.h"
#include <iostream>
#include "vecops.h"
#include "cutil.h"
#include "mcrx-debug.h"
#include "thermal_equilibrium_grain.h"
#include "wd01_grain_model.h"
#include "wd01_Brent_PAH_grain_model.h"
#include "grain_model_hack.h"
#include "counter.h"
#include "mcrx-units.h"
#include "grain_size.h"
#include "CCfits/CCfits"
#include "fits-utilities.h"
#include "blitz-fits.h"
#include "shrink_expand.h"
#include "boost/thread/thread.hpp"
#include "blackbody.h"
#include "boost/program_options.hpp"
#include "cuda_grain_temp.h"
#include "preferences.h"

// Shorthand for the Boost.Program_options namespace, used by main() to
// declare and parse the command-line options.
namespace po = boost::program_options;

// Single-precision scalar type; main() uses it for the accuracy
// constant handed to the grain object.
typedef float T_float;

// Unqualified std and blitz names are used throughout this file.
using namespace std;
using namespace blitz;
// mcrx array types used for the data arrays in main().
using mcrx::array_1;
using mcrx::array_2;
// Single-precision blitz arrays. NOTE(review): not referenced in the
// visible part of this file -- confirm whether they are still needed.
typedef blitz::Array<float, 1> farray_1;
typedef blitz::Array<float, 2> farray_2;

using mcrx::thermal_equilibrium_grain;


/** CUDA grain temperature/SED benchmark driver.

    Parses the command line, initializes the requested CUDA device,
    loads the radiation intensity and dust mass for a range of cells
    from FITS files, and runs the SED calculation on the GPU (either
    for a single thermal_equilibrium_grain object or for a full grain
    model). Unless run_cpu_calculation is false, the same calculation
    is then run on the CPU, the two results are compared, and the
    arrays are dumped to the output file.

    The dump consists of a text header line "nc<TAB>ns<TAB>nl"
    followed by raw doubles: lambda, grain sizes, [heating (GPU),
    heating (CPU), temp (GPU), temp (CPU)], sed (GPU), sed (CPU). The
    bracketed arrays are only present when check_all is in effect,
    because they are not allocated otherwise.

    @return 0 on success (including --help), 1 on a usage or I/O error
    or if the GPU and CPU SEDs disagree by more than the tolerance. */
int main(int argc, char** argv)
{
  // Declare the supported options.
  po::options_description desc("CUDA temp calc benchmark. Allowed options");
  desc.add_options()
    ("output_file", po::value<string>(), "output file name")
    ("intensity_file", po::value<string>()->default_value("/data/patrik/test/Sbc11i4-u4/test/mcrx_010.fits"), "input file name")
    ("grid_file", po::value<string>()->default_value("/data/patrik/sims/Sbc11i4-u4/set5as/grid_010.fits"), "input file name")
    ("grain_dir", po::value<string>()->default_value("/home/patrik/dust_data/crosssections"), "directory containing grain cross sections files")
    ("grain_file", po::value<string>()->default_value("graphite.fits"), "name of grain cross section file")
    ("grain_set", po::value<string>()->default_value("DL07_MW3.1_60"), "name of WD01 grain model set")
    ("n_cells", po::value<size_t>(), "number of cells")
    ("n_threads", po::value<int>()->default_value(8), "number of threads")
    ("start_cell", po::value<size_t>()->default_value(0), "cell to start with")
    ("block_size", po::value<size_t>()->default_value(0), "cell blocksize")
    ("check_all", po::value<bool>()->default_value(false), "check temp and heating in addition to SED")
    ("use_lookup", po::value<bool>()->default_value(false), "use a lookup table for CPU instead of solving for T")
    ("use_grain_model", po::value<bool>()->default_value(false), "do calculation for a full grain model as opposed to a single grain object")
    ("run_cpu_calculation", po::value<bool>()->default_value(true), "run the CPU calculation")
    ("cuda_device", po::value<int>()->default_value(0), "CUDA device to use")
    ("help", "print this page")
    ;

  po::positional_options_description p;
  p.add("output_file", 1);

  po::variables_map opt;
  po::store(po::command_line_parser(argc, argv).
          options(desc).positional(p).run(), opt);
  po::notify(opt);
  if (opt.count("help")) {
    // asking for help is not an error
    cout << desc << "\n";
    return 0;
  }
  if(!opt.count("output_file")) {
    cout << "Must specify output file name" << endl;
    return 1;
  }
  if(!opt.count("n_cells")) {
    cout << "Must specify number of cells" << endl;
    return 1;
  }

  mcrxcuda::cuda_init(opt["cuda_device"].as<int>());

  const size_t startc=opt["start_cell"].as<size_t>();
  size_t nc=opt["n_cells"].as<size_t>();
  const bool use_grain_model=opt["use_grain_model"].as<bool>();
  bool check_all=opt["check_all"].as<bool>();
  // The grain model calculation only returns the SED, so the heating
  // and temperature arrays would never be filled in and the checks
  // below would be comparing the initial sentinel values.
  if(check_all && use_grain_model) {
    cerr << "Warning: check_all is not available with use_grain_model, "
      "ignoring check_all\n";
    check_all=false;
  }

  // load intensity and m_dust from a fits file.
  CCfits::FITS datafile(opt["intensity_file"].as<string>());
  // load intensity
  array_2 intensity;
  read(open_HDU(datafile, "INTENSITY"), intensity);
  cout << "Intensity data contains " << intensity.shape() << endl;

  // Clamp the requested cell range to what is actually in the file,
  // taking the start cell into account. n_cells=0 means "all
  // remaining cells".
  const size_t n_cells_file=static_cast<size_t>(intensity.extent(firstDim));
  if(startc>=n_cells_file) {
    cerr << "Error: start_cell " << startc << " is outside the "
	 << n_cells_file << " cells in the intensity data\n";
    return 1;
  }
  const size_t n_avail=n_cells_file-startc;
  if((nc==0)||(nc>n_avail))
    nc=n_avail;

  // if using check_all, make sure nc is multiple of the blocksize
  // (16)... This must be tested on the final (clamped) cell count.
  if(check_all && (nc%16)) {
    cerr << "Error: When using check_all, n_cells must be a multiple of 16\n";
    return 1;
  }

  intensity.reference(intensity(Range(startc,startc+nc-1), Range::all()));
  intensity = pow(10.0,intensity);
  intensity.makeUnique();
  assert(intensity.isStorageContiguous());
  cout << "Max intensity at " << maxIndex(intensity) << '\t' << max(intensity) <<endl;
  
  ExtHDU& lambda_hdu = open_HDU(datafile, "LAMBDA");
  array_1 lambda,lambda_int;
  read(lambda_hdu.column("lambda"), lambda);
  cout << "Lambda data contains " << lambda.shape() << endl;
  read(lambda_hdu.column("lambda_intensity"), lambda_int);
  int n_lambda_int;
  lambda_hdu.readKey("NLAMBDA_INTENSITY", n_lambda_int);
  lambda_int.resizeAndPreserve(n_lambda_int);

  // load m_dust
  CCfits::FITS gridfile(opt["grid_file"].as<string>());
  array_1 m_dust;
  Column& c_m_metals = open_HDU(gridfile, "GRIDDATA").column("mass_metals");
  read(c_m_metals, m_dust);
  cout << "Dust mass data contains " << m_dust.shape() << endl;

  // the grid file must cover the same cell range as the intensity
  if(static_cast<size_t>(m_dust.extent(firstDim))<startc+nc) {
    cerr << "Error: dust mass data contains only "
	 << m_dust.extent(firstDim) << " cells but cells " << startc
	 << "-" << startc+nc-1 << " were requested\n";
    return 1;
  }
  
  const float x=0.4;//*units::convert("sunmass","kg");
  m_dust.reference(m_dust(Range(startc,startc+nc-1)));
  m_dust.makeUnique();
  m_dust=x*m_dust;
  assert(m_dust.isStorageContiguous());

  //load grain model
  Preferences prefs;
  prefs.setValue("grain_data_directory", opt["grain_dir"].as<string>());
  prefs.setValue("wd01_parameter_set", opt["grain_set"].as<string>());
  prefs.setValue("use_dl07_opacities", true); 
  prefs.setValue("use_cuda", true);
  prefs.setValue("n_threads", opt["n_threads"].as<int>());
  prefs.setValue("template_pah_fraction", 0.5);

  mcrx::T_unit_map units;
  units["wavelength"]="m";
  units["length"]="m";
  units["luminosity"]="W";
  units["mass"]="Msun";
  units["time"]="yr";

  // two otherwise identical grain models, one with use_cuda set and
  // one without
  mcrx::wd01_Brent_PAH_grain_model<mcrx::polychromatic_scatterer_policy, 
    mcrx::local_random> gm_cuda(prefs, units);
  gm_cuda.resample(lambda_int, lambda);
  prefs.setValue("use_cuda", false);
  mcrx::wd01_Brent_PAH_grain_model<mcrx::polychromatic_scatterer_policy, 
    mcrx::local_random> gm_cpu(prefs, units);
  gm_cpu.resample(lambda_int, lambda);

  mcrx::thermal_equilibrium_grain g
    (opt["grain_dir"].as<string>()+"/"+opt["grain_file"].as<string>(), 
     Preferences());

  g.use_lookup_ = opt["use_lookup"].as<bool>();
  const T_float accuracy = 1e-3;
  g.accuracy_ = accuracy;

  g.resample(lambda_int,lambda);
  // size distribution
  mcrx::wd01_graphite_distribution size_distribution
    (opt["grain_set"].as<string>(), units);
  // number of grains in the size bin
  array_1 dn(size_distribution.dn_da(g.asizes())*
	      g.delta_size());

  const size_t nl=lambda.size();
  const size_t ns=g.sigma().extent(firstDim);

  // arrays are intensity(cell, lambda), sigma(size, lambda),
  // heating(size, cell).
  cout << "Intensity size " << intensity.shape() << endl;
  cout << "Sigma size " << g.sigma().shape() << endl;
  cout << "Lambda size " << lambda.shape() << endl;
  cout << "m_dust size " << m_dust.shape() << endl;


  // return values (heating and temp are transposed compared to our
  // convention to get by max pitch when allocating cuda arrays, ie #
  // of cells must be first dim). They are only allocated, and filled
  // with a sentinel value, when check_all is in effect.
  array_2 heating;
  array_2 temp;
  if(check_all) {
    heating.resize(ns,nc); 
    temp.resize(ns,nc);
    heating=1e30;
    temp=1e30;
  }
  array_2 sed(nc,nl);
  sed=0;

  // some sanity checks
  assert(intensity.isStorageContiguous());
  assert(m_dust.isStorageContiguous());
  assert(dn.isStorageContiguous());
  assert(sed.isStorageContiguous());
  assert(g.sigma().isMinorRank(secondDim));
  assert(intensity.isMinorRank(secondDim));
  assert(sed.isMinorRank(secondDim));
  assert(sed.extent(firstDim)==nc);
  assert(sed.extent(secondDim)==nl);
  assert(g.sigma().extent(firstDim)==dn.size());
  assert(dn.size()==ns);

  unsigned int timer=0, gputime=0;
  cutCreateTimer(&gputime);
  cutStartTimer(gputime);

  if(use_grain_model)
    gm_cuda.calculate_SED(intensity, m_dust, sed);
  else
    g.CUDA_calculate_SED_from_intensity(intensity,
					m_dust,
					dn, sed, true,
					heating, temp);
  
  cutStopTimer(gputime);

  if(check_all) {
    // gpu calculation has a missing 4pi in the heating values, so we
    // add those here for the comparison
    heating*= 4*constants::pi;

    ASSERT_ALL(heating==heating);
    ASSERT_ALL(heating>=0);
    ASSERT_ALL(heating<1e30);
    ASSERT_ALL(temp==temp);
    ASSERT_ALL(temp>=0);
    ASSERT_ALL(temp<2000);
  }
  ASSERT_ALL(sed==sed);
  ASSERT_ALL(sed>=0);
  ASSERT_ALL(sed<1e100);

  if(!opt["run_cpu_calculation"].as<bool>()) {
    // return rather than exit() so the timer is released and the
    // locals (FITS files etc.) are destroyed normally
    cutDeleteTimer(gputime);
    return 0;
  }

  // we should now have the result in heating. test it

  array_2 heating_standard;
  array_2 temp_standard;
  if(check_all) {
    heating_standard.resize(ns,nc); 
    temp_standard.resize(ns,nc);
    heating_standard=1e30;
    temp_standard=1e30;
  }
  array_2 sed_standard(nc,nl);
  sed_standard=0;

  cout << "Calculating comparison result" << endl;

  cutCreateTimer(&timer);
  cutStartTimer(timer);

  counter cc(1000);

  if(use_grain_model)
    gm_cpu.calculate_SED(intensity, m_dust, sed_standard);
  else
    g.start_threads(intensity, m_dust, dn,
		    sed_standard, 
		    true,
		    opt["n_threads"].as<int>(), 
		    opt["block_size"].as<size_t>(), 
		    check_all ? &heating_standard: 0, 
		    check_all ? &temp_standard : 0);

  cutStopTimer(timer);

  if(check_all) {
    ASSERT_ALL(heating_standard==heating_standard);
    ASSERT_ALL(heating_standard>=0);
    ASSERT_ALL(temp_standard==temp_standard);
    ASSERT_ALL(temp_standard>=0);
    ASSERT_ALL(temp_standard<2000);
  }
  ASSERT_ALL(sed_standard==sed_standard);
  ASSERT_ALL(sed_standard<1e40);


  printf("Processing time: %f (ms) \n", cutGetTimerValue(timer));

  cout << "\nGPU net acceleration: " 
       << cutGetTimerValue(timer)/cutGetTimerValue(gputime) << "x\n";
  cutDeleteTimer(timer);
  cutDeleteTimer(gputime);
  
  if(check_all) {
    cout << "\nMean CPU heating " << mean(heating_standard)<< endl;
    cout << "Mean GPU heating " << mean(heating)<< endl;
    cout << "min/max GPU " << min(heating) << ' ' << max(heating) << endl;
    cout << "RMS fractional difference of heating: " 
	 << sqrt(mean((heating-heating_standard)
		      *(heating-heating_standard)))/
      mean(heating_standard)
	 << endl;
    cout << "Max fractional difference of heating: " 
	 << max(abs(heating-heating_standard)/
		heating_standard)
	 << endl;

    cout << "\nMean CPU grain temp " << mean(temp_standard)<< endl;
    cout << "minmax CPU: " << min(temp_standard) << ' ' << max(temp_standard) << endl;
    cout << "Mean GPU grain temp " << mean(temp)<< endl;
    cout << "minmax GPU: " << min(temp) << ' ' << max(temp) << endl;
    cout << "RMS fractional difference of temps: " 
	 << sqrt(mean((temp-temp_standard)*(temp-temp_standard)))/
      mean(temp_standard)
	 << endl;
    cout << "Max fractional difference of temps: " 
	 << max(abs(temp-temp_standard)/temp_standard)
	 << endl;
  }
  else {
    cout << "To compare heating and temperatures, use check_all=1\n";
  }

  cout << "\nMean CPU sed: " << mean(sed_standard)<< endl;
  cout << "minmax CPU: " << min(sed_standard) << ' ' << max(sed_standard) << endl;
  cout << "Mean GPU sed: " << mean(cast<double>(sed))<< endl;
  cout << "minmax GPU: " << min(sed) << ' ' << max(sed) << endl;
  const double sedrms =  
    sqrt(mean((sed-sed_standard)*(sed-sed_standard)))/mean(sed_standard);
  cout << "RMS fractional difference of seds: " << sedrms
       << endl;
  cout << "Max absolute difference of seds (ignoring float underflows): " 
       << max(abs(sed-sed_standard))
       << endl;

  // Written as "<=" so that a NaN rms (e.g. from an all-zero
  // reference SED) counts as a failure instead of a pass.
  const double sed_tolerance = 1e-3;
  const bool passed = (sedrms<=sed_tolerance);
  if(!passed)
    cerr << "\nTEST FAILED!\n" << endl;
  else
    cout << "\nPassed\n" << endl;

  cout << "Dumping data." << endl;

  // the file carries raw binary data after the text header line
  ofstream of(opt["output_file"].as<string>().c_str(),
	      ios::out | ios::binary);
  if(!of) {
    cerr << "Error: could not open output file "
	 << opt["output_file"].as<string>() << "\n";
    return 1;
  }
  of << nc << '\t' << ns << '\t' << nl << '\n';
  cout << heating.shape() << heating_standard.shape() << endl;
  cout << heating.stride() << heating_standard.stride() << endl;
  cout << heating.ordering() << heating_standard.ordering() << endl;
  cout << temp.shape() << temp_standard.shape() << endl;
  cout << temp.stride() << temp_standard.stride() << endl;
  //assert(heating.isStorageContiguous());
  //assert(heating_standard.isStorageContiguous());
  of.write(reinterpret_cast<char*>(lambda.data()),sizeof(double)*nl);
  of.write(reinterpret_cast<const char*>(g.asizes().data()),sizeof(double)*ns);

  // The heating and temperature arrays only exist when check_all is
  // in effect; writing nc*ns elements from the unallocated arrays
  // would read invalid memory.
  if(check_all) {
    of.write(reinterpret_cast<char*>(heating.data()),sizeof(double)*nc*ns);
    of.write(reinterpret_cast<char*>(heating_standard.data()),sizeof(double)*nc*ns);

    of.write(reinterpret_cast<char*>(temp.data()),sizeof(double)*nc*ns);
    of.write(reinterpret_cast<char*>(temp_standard.data()),sizeof(double)*nc*ns);
  }
  else {
    cout << "Heating and temperature arrays not dumped (check_all not in effect)." << endl;
  }

  of.write(reinterpret_cast<char*>(sed.data()),sizeof(double)*nc*nl);
  of.write(reinterpret_cast<char*>(sed_standard.data()),sizeof(double)*nc*nl);

  of.close();
  if(!of) {
    cerr << "Error: writing output file "
	 << opt["output_file"].as<string>() << " failed\n";
    return 1;
  }

  /*
  ofstream o("crap");
  for(int i=0;i<nc;++i)
    for(int j=0;j<ns;++j)
      o << i << ' ' << j << ' ' << heating(i,j) << ' ' << heating_standard(i,j) << '\n';
  */

  // let scripts detect a failed comparison
  return passed ? 0 : 1;
}

// Previous implementation of the threaded calculation, retained
// (commented out) for reference only; it now lives in
// thermal_equilibrium_grain.

/*
class temp_thread {
public:
  static const int cache_line_size = 128;

  /// To ensure threads don't share cache line.
  char padding [cache_line_size]; 
  bool use_lookup_;

  thermal_equilibrium_grain g_;

  const array_2& intensity_;
  const array_1& dnd_;
  const array_1& m_dust_;
  
  array_2 heating_;
  array_2 temp_;
  array_2 sed_;

  /// list of cell blocks to be processed
  vector<pair<size_t,size_t> >& clist_;

  /// Reference to the mutex protecting the cell vector
  boost::mutex& cell_mutex_;


public:
  temp_thread(const thermal_equilibrium_grain& g,
	      const array_2& i, const array_1& d,
	      const array_1& md,
	      array_2& h,
	      array_2& t,
	      array_2& s,
	      vector<pair<size_t,size_t> >& cl,
	      boost::mutex& cm,
	      bool ul) :
    g_(g),
    intensity_(i),
    dnd_(d),
    m_dust_(md),
    clist_(cl), cell_mutex_(cm), use_lookup_(ul)
  {
    //intensity_.weakReference(i);
    //dnd_.weakReference(d);
    //m_dust_.weakReference(md);
    heating_.weakReference(h);
    temp_.weakReference(t);
    sed_.weakReference(s);

  };

  temp_thread(const temp_thread& t) :
    g_(t.g_),
    intensity_(t.intensity_), 
    dnd_(t.dnd_), 
    m_dust_(t.m_dust_), heating_(t.heating_), temp_(t.temp_), sed_(t.sed_),
    clist_(t.clist_), cell_mutex_(t.cell_mutex_), use_lookup_(t.use_lookup_)
  {};

  void operator () ();
  void run_range(size_t cmin, size_t cmax);
};

void startem(const thermal_equilibrium_grain& g,
	     const array_2& intensity, const array_1& dnd,
	     const array_1& m_dust, array_2& heating,
	     array_2& temp, array_2& sed,
	     int n_threads, size_t block_size, bool use_lookup)
{
  const int nc=intensity.extent(firstDim);
  if(block_size==0)
    block_size=size_t(ceil(1.0*nc/n_threads));
  cout << "Block size is " << block_size << endl;

  // make list
  vector<pair<size_t, size_t> > clist;
  for (size_t c=0; c<nc; c+=block_size) {
    size_t cend= c+block_size-1;
    if (cend>=nc)
      cend=nc-1;
    //cout << "Adding range " << c << "-" << cend << " to list" << endl;
    clist.push_back(make_pair(c,cend));
  }
  //cout << "Block list has " << clist.size() << " entries" << endl;

  // create threads
  boost::thread_group threads;
  std::vector<boost::shared_ptr<temp_thread> > thread_objects;
  boost::mutex cell_mutex;
  const float nper=nc/n_threads;
  cout << "Running " << nper << " cells per thread." << endl;
  for (int i = 0; i < n_threads; ++i) {
    // make sure we run the last cell
    //const size_t cstart = static_cast<size_t>(floor(nper*i));
    //const size_t cend = static_cast<size_t>(ceil(nper*(i+1)-1));
    //const int cend= (i==n_threads-1) ? nc-1 : cstart+nc/n_threads-1;
    //cout << "\tThread " << i << " cells " << cstart << "-" << cend << endl;
    thread_objects.push_back(boost::shared_ptr<temp_thread> 
			     (new temp_thread
			      (g, intensity, dnd, m_dust,
			       heating, temp, sed,
			       clist, cell_mutex, use_lookup)));
  }
  cout << " Spawning " << n_threads << " threads" << endl;
  for (int i = 0; i < n_threads; ++i) {
    threads.create_thread(*thread_objects[i]);
  }
  // wait for them to die
  threads.join_all();
}

void temp_thread::operator()()
{
  while (true) {
    // pop another cell block off the stack, exit if empty
    size_t cstart, cend;
    {
      // open scope for locking mutex
      boost::mutex::scoped_lock stack_lock (cell_mutex_);
      if (clist_.empty())
	break;
      cstart = clist_.back().first;
      cend = clist_.back().second;
      clist_.pop_back();
    }

  // here we block the calculation into reasonably-sized blocks to
  // avoid cache thrashing
    //for (int cstart = cmin_; cstart < cmax_; cstart+=block_size_) {
    //size_t cend = cstart+block_size_-1;
    // make sure we cut the last block
    //cend = (cend>cmax_) ? cmax_ : cend;
    cout << "\tRunning cells " << cstart << "-" << cend << endl;
    run_range(cstart,cend);
  }
  cout << "\tThread done" << endl;
}
 
void temp_thread::run_range(size_t cmin, size_t cmax)
{
  const int ns=g_.sigma().extent(firstDim);
  const int nl=g_.sigma().extent(secondDim);
  const array_2 dintensity
    (cast<double>(intensity_(Range(cmin,cmax), Range::all())));
  sed_(Range(cmin,cmax),Range::all())=0;

  for(int s=0;s<ns;++s) {
    for(int c=cmin; c<=cmax; ++c) {
      if(use_lookup_) {
	temp_(s,c) = g_.calculate_T(s, dintensity(c-cmin,Range::all()));
	heating_(s,c) = g_.absorption()(s);
      }
      else {
	g_.set_intensity(s,dintensity(c-cmin,Range::all()));
	heating_(s,c) = g_.absorption()(s);

	T_float xx = log10(g_.absorption()(s)) - 
	  1.39*log10(g_.asizes()(s)) + 
	  0.126*pow(log10(g_.asizes()(s)),2);
	T_float Tguess =
	  pow(10,1.86+0.189*xx+3.41e-3*xx*xx);

	  //T_float Tguess =
	  //6.79*pow(g_.absorption()(s), 0.181)*pow(g_.asizes()(s),-3.183*0.181);

	Tguess = (Tguess<5)?5:Tguess;
	
	temp_(s,c)= g_.solve_for_T(s, Tguess);
      }

      assert(temp_(s,c)>=0);
      assert(temp_(s,c)<1e4);

      array_1 sed1(g_.calculate_SED(s,temp_(s,c)));
      sed_(c,Range::all()) +=
	g_.sigma()(s, Range::all())*
	mcrx::B_lambda(g_.invlambda(), 1./temp_(s,c))*
	dnd_(s);
      }
  }

  for(int c=cmin; c<=cmax; ++c) {
    sed_(c,Range::all()) *= 4*constants::pi*m_dust_(c);
  }

}
*/
